{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "acd7b15e",
   "metadata": {},
   "source": [
    "# Dreambooth with OFT\n",
    "This notebook assumes that you have already run the `train_dreambooth.py` script to create your own adapter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acab479f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import DiffusionPipeline\n",
    "from diffusers.utils import check_min_version, get_logger\n",
    "from peft import PeftModel\n",
    "\n",
    "# Raises an error if the installed diffusers version is older than the minimum required one.\n",
    "# Remove this check at your own risk.\n",
    "check_min_version(\"0.10.0.dev0\")\n",
    "\n",
    "logger = get_logger(__name__)\n",
    "\n",
    "# Base model that is passed to DiffusionPipeline.from_pretrained.\n",
    "BASE_MODEL_NAME = \"stabilityai/stable-diffusion-2-1-base\"\n",
    "# Placeholder: replace it with the adapter directory; the adapters are loaded from its\n",
    "# `unet` and `text_encoder` subfolders.\n",
    "ADAPTER_MODEL_PATH = \"INSERT MODEL PATH HERE\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2e8b17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Load the base pipeline and move its components to the selected device.\n",
    "pipe = DiffusionPipeline.from_pretrained(BASE_MODEL_NAME)\n",
    "pipe.to(DEVICE)\n",
    "\n",
    "# Wrap the UNet and the text encoder with the adapters stored in the\n",
    "# `unet` and `text_encoder` subfolders of ADAPTER_MODEL_PATH.\n",
    "pipe.unet = PeftModel.from_pretrained(pipe.unet, f\"{ADAPTER_MODEL_PATH}/unet\", adapter_name=\"default\")\n",
    "pipe.text_encoder = PeftModel.from_pretrained(\n",
    "    pipe.text_encoder, f\"{ADAPTER_MODEL_PATH}/text_encoder\", adapter_name=\"default\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f4c60a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"A photo of a sks dog\"\n",
    "\n",
    "# Seed the sampler so that re-running the notebook reproduces the same image;\n",
    "# change SEED to draw a different sample.\n",
    "SEED = 0\n",
    "generator = torch.Generator(device=pipe.device).manual_seed(SEED)\n",
    "\n",
    "image = pipe(\n",
    "    prompt,\n",
    "    num_inference_steps=50,\n",
    "    height=512,\n",
    "    width=512,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "# Last expression of the cell: displays the generated image.\n",
    "image"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
